import random
from typing import Any, List, Dict
from collections import defaultdict
from pipelines.prompta.utils import tuple2word, word2tuple
from prompta.core.language import BaseLanguage
from .probabilistic_abstract_oracle import ProbabilisticAbstractOracle


class ProbabilisticSimpleOracle(ProbabilisticAbstractOracle):
    """Membership oracle that sometimes inverts the true answer of ``language``.

    The noise-free answer comes from ``self.language.in_language(input_str)``.
    Two noise modes are available, selected by the ``reproducable`` keyword
    argument (the spelling is part of the public interface and is kept as-is):

    * ``reproducable=False`` (default): each query's answer is inverted
      independently with probability ``self.prob``, drawn from the global
      :mod:`random` generator.
    * ``reproducable=True``: noise is deterministic. Queries are counted from
      1, and every ``rnd_lim``-th query is inverted, where
      ``rnd_lim = max(1, int(1 / prob))``. Because ``int`` truncates, the
      effective flip rate for ``0 < prob <= 1`` is ``1 / int(1 / prob)``,
      which can be somewhat higher than ``prob`` (e.g. ``prob=0.3`` gives a
      period of 3, i.e. a rate of 1/3).
    """

    def __init__(self, language: BaseLanguage, prob: float = 0.1, *args: Any, **kwargs: Any) -> None:
        """Create the oracle.

        Args:
            language: Language whose ``in_language`` method supplies the
                noise-free membership answer.
            prob: Flip probability. It is forwarded to the base class
                constructor and read back below as ``self.prob``.
            *args: Forwarded unchanged to the base class constructor.
            **kwargs: Forwarded unchanged to the base class constructor.
                The key ``reproducable`` (default ``False``) is also read
                here to select the deterministic noise mode.
        """
        super().__init__(language, prob, *args, **kwargs)
        self.reproducable = kwargs.get('reproducable', False)
        if self.reproducable:
            # Number of membership queries answered so far in deterministic mode.
            self.rnd = 0
            if self.prob > 0:
                # Clamp to at least 1. For prob > 1, int(1 / prob) is 0 and the
                # modulo in _get_membership_query_result would raise
                # ZeroDivisionError. A period of 1 flips every answer, which
                # matches the random mode, where random.random() < prob always
                # holds once prob >= 1.
                self.rnd_lim = max(1, int(1 / self.prob))
            else:
                # Non-positive prob: use a huge period so that, in practice,
                # no answer is flipped (the first flip would be query 1e10).
                self.rnd_lim = int(1e+10)

    def _get_membership_query_result(self, input_str: str, *args: Any, **kwargs: Any) -> Any:
        """Answer a membership query, possibly inverting the true answer.

        Args:
            input_str: The word being queried.
            *args: Accepted for interface compatibility; unused here.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``self.language.in_language(input_str)``, or its logical negation
            (a ``bool``) when this query is selected for noise.
        """
        res = self.language.in_language(input_str)

        if self.reproducable:
            # 1-based query count: flip on queries rnd_lim, 2*rnd_lim, ...
            if (self.rnd + 1) % self.rnd_lim == 0:
                res = not res
            self.rnd += 1
        else:
            r = random.random()
            if r < self.prob:
                res = not res

        return res
